import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib
# Select the non-interactive Agg backend BEFORE pyplot is imported, so figures can
# be rendered straight to image files on machines without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Global sans-serif fallback list: SimHei (a Chinese font) first so Chinese text can
# render, followed by fonts with broad Unicode coverage.
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']
# Draw minus signs with an ASCII hyphen instead of U+2212, which fonts such as
# SimHei may lack (it would otherwise render as a missing-glyph box).
matplotlib.rcParams['axes.unicode_minus'] = False
import seaborn as sns
from scipy import signal
from scipy.stats import pearsonr
from sklearn.metrics import classification_report
import argparse
import os
from tqdm import tqdm

from clipper import Clipper
from model import DiffusionEEGModel, ImageToEEGModel
from data import get_eeg_dls
import utils



def ablation_study(args):
    """Ablation study: build one model per configuration and evaluate it.

    Every configuration is evaluated with ``evaluate_comprehensive`` on the test
    split of ``args.subject``. Currently only the "Concatenation Method" variant
    is listed; its ``use_cross_attention`` / ``use_concat`` flags are descriptive
    and are not read below.

    Args:
        args: Parsed CLI namespace. Reads ``device``, ``subject``, ``data_path``,
            ``batch_size``, ``num_workers``, ``seed``, ``eeg_channels``,
            ``eeg_length`` and ``num_train_timesteps``.

    Returns:
        dict mapping configuration name -> metrics dict returned by
        ``evaluate_comprehensive``.
    """
    print("Starting ablation study...")

    # Local import: model_concat is only required by this experiment.
    from model_concat import DiffusionEEGModelWithConcat, ImageToEEGModelWithConcat

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Different configurations - only keep Concatenation Method option
    configs = [
        {'name': 'Concatenation Method', 'use_cross_attention': False, 'clip_variant': 'ViT-L/14', 'use_concat': True},
    ]

    # The test split depends only on ``args``, not on the configuration, so it is
    # loaded once and shared by every configuration.
    _, test_dl = get_eeg_dls(
        subject=args.subject,
        data_path=args.data_path,
        batch_size=args.batch_size,
        val_batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed
    )

    results = {}

    for config in configs:
        print(f"Testing configuration: {config['name']}")

        clip_model = Clipper(clip_variant=config['clip_variant'], device=device)

        # ViT-L/14 embeddings are 768-d; other variants are assumed to be 512-d.
        diffusion_model = DiffusionEEGModelWithConcat(
            eeg_channels=args.eeg_channels,
            eeg_length=args.eeg_length,
            hidden_dim=768 if config['clip_variant'] == 'ViT-L/14' else 512,
            num_train_timesteps=args.num_train_timesteps,
            device=device
        ).to(device)

        model = ImageToEEGModelWithConcat(clip_model, diffusion_model).to(device)

        # NOTE(review): no checkpoint is loaded in this function. Unless the
        # constructors load weights themselves, the model is evaluated untrained.
        # Confirm this is intended, or load the trained concat-model weights here.
        metrics = evaluate_comprehensive(model, test_dl, device)
        results[config['name']] = metrics

    return results


def evaluate_mse_only(model, dataloader, device):
    """Simplified evaluation that only computes the MSE loss.

    For each ``(eeg, image)`` batch:
    - 4-D EEG is passed through ``utils.average_eeg_trials``.
    - 3-D EEG then gets a singleton dimension inserted at axis 1.
    - EEG is generated from the CLIP image embedding and compared with the real EEG.

    Args:
        model: Object exposing ``clip_model.embed_image`` and
            ``diffusion_model.generate_eeg`` (plus ``eval()``).
        dataloader: Iterable of ``(eeg_data, image_data)`` tensor pairs.
        device: Torch device that the batches are moved to.

    Returns:
        float: Sample-weighted average of the per-batch mean MSE. Returns NaN
        when the dataloader yields no samples, which matches the ``np.mean`` of
        an empty list returned by the other evaluators instead of raising
        ``ZeroDivisionError``.
    """
    model.eval()

    total_mse = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch_data in tqdm(dataloader, desc='MSE evaluation', leave=False):
            eeg_data, image_data = batch_data

            # Data preprocessing
            eeg_data = eeg_data.float().to(device)
            image_data = image_data.float().to(device)

            if eeg_data.dim() == 4:
                eeg_data = utils.average_eeg_trials(eeg_data)

            # Normalization via utils.normalize_eeg is currently disabled.

            if eeg_data.dim() == 3:
                eeg_data = eeg_data.unsqueeze(1)

            # Generate EEG
            generated_eeg = model.diffusion_model.generate_eeg(
                model.clip_model.embed_image(image_data).float()
            )

            # Batch-mean MSE, re-weighted by batch size so that a smaller final
            # batch does not skew the overall average.
            mse = F.mse_loss(generated_eeg, eeg_data, reduction='mean')
            total_mse += mse.item() * eeg_data.size(0)
            total_samples += eeg_data.size(0)

    if total_samples == 0:
        return float('nan')
    return total_mse / total_samples


def evaluate_comprehensive(model, dataloader, device):
    """Evaluate generated EEG against real EEG with correlation and MSE statistics.

    Batches are preprocessed as follows:
    - 4-D EEG goes through ``utils.average_eeg_trials``.
    - 3-D EEG gets a singleton dimension at axis 1.
    - EEG is generated from the CLIP image embedding.

    Per batch, ``utils.compute_correlation`` and an element-wise MSE averaged
    over the last two dimensions are collected.

    Returns:
        dict with ``'correlation'`` -> {'mean', 'std', 'median'} and
        ``'mse'`` -> {'mean', 'std'}, all computed with NumPy over every
        collected value.
    """
    model.eval()

    corr_values = []
    mse_values = []

    with torch.no_grad():
        for real_eeg, images in tqdm(dataloader, desc='Comprehensive evaluation'):
            real_eeg = real_eeg.float().to(device)
            images = images.float().to(device)

            if real_eeg.dim() == 4:
                real_eeg = utils.average_eeg_trials(real_eeg)
            # Normalization via utils.normalize_eeg is currently disabled.
            if real_eeg.dim() == 3:
                real_eeg = real_eeg.unsqueeze(1)

            image_embeds = model.clip_model.embed_image(images).float()
            fake_eeg = model.diffusion_model.generate_eeg(image_embeds)

            batch_corr = utils.compute_correlation(fake_eeg, real_eeg)
            squared_err = F.mse_loss(fake_eeg, real_eeg, reduction='none')
            batch_mse = squared_err.mean(dim=(-2, -1))

            corr_values.extend(batch_corr.cpu().numpy())
            mse_values.extend(batch_mse.cpu().numpy())

    correlation_stats = {
        'mean': np.mean(corr_values),
        'std': np.std(corr_values),
        'median': np.median(corr_values),
    }
    mse_stats = {
        'mean': np.mean(mse_values),
        'std': np.std(mse_values),
    }
    return {'correlation': correlation_stats, 'mse': mse_stats}


def cross_subject_evaluation(args):
    """Cross-subject evaluation: test each subject's best model on every subject.

    For each training subject 1..10, the checkpoint is read from
    ``<model_path>/subject<N>/best_model_<loss_type>.pth``. If it exists, the
    model is rebuilt, the weights are loaded, and the model is evaluated on the
    test split of every subject 1..10.

    Both metrics come from a single ``evaluate_comprehensive`` pass over each
    test loader. Its mean per-sample MSE equals the sample-weighted mean that
    ``evaluate_mse_only`` reports (for same-shaped generated and real EEG), and
    its mean correlation is what ``evaluate_correlation`` reports. One pass
    halves the generation cost and scores MSE and correlation on the same
    generated EEG.

    Args:
        args: Parsed CLI namespace. Reads ``device``, ``model_path``,
            ``loss_type``, ``clip_variant``, ``eeg_channels``, ``eeg_length``,
            ``hidden_dim``, ``num_train_timesteps``, ``data_path``,
            ``batch_size``, ``num_workers`` and ``seed``.

    Returns:
        tuple ``(all_results, all_correlations)``. Each is a dict mapping
        ``'train_subject_<N>'`` -> {``'test_subject_<M>'``: float or None}.
        ``None`` marks a failed test. Training subjects whose checkpoint is
        missing or fails to load are omitted.
    """
    print("Starting cross-subject evaluation...")

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    all_results = {}       # train key -> {test key: MSE or None}
    all_correlations = {}  # train key -> {test key: correlation or None}

    for train_subject in range(1, 11):
        print(f"\nTesting model of subject {train_subject}...")

        model_path = os.path.join(args.model_path, f'subject{train_subject}', f'best_model_{args.loss_type}.pth')
        if not os.path.exists(model_path):
            print(f"Warning: Model file for subject {train_subject} does not exist: {model_path}")
            continue

        clip_model = Clipper(clip_variant=args.clip_variant, device=device)
        diffusion_model = DiffusionEEGModel(
            eeg_channels=args.eeg_channels,
            eeg_length=args.eeg_length,
            hidden_dim=args.hidden_dim,
            num_train_timesteps=args.num_train_timesteps,
            device=device
        ).to(device)

        model = ImageToEEGModel(clip_model, diffusion_model).to(device)

        # Best effort: a broken checkpoint skips this subject instead of aborting the sweep.
        try:
            state_dict = torch.load(model_path, map_location=device)
            model.load_state_dict(state_dict)
            print(f"Successfully loaded model for subject {train_subject}")
        except Exception as e:
            print(f"Failed to load model for subject {train_subject}: {e}")
            continue

        train_subject_results = {}
        train_subject_correlations = {}

        for test_subject in range(1, 11):
            print(f"  Testing on data from subject {test_subject}...")
            test_key = f'test_subject_{test_subject}'

            try:
                _, test_dl = get_eeg_dls(
                    subject=test_subject,
                    data_path=args.data_path,
                    batch_size=args.batch_size,
                    val_batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    seed=args.seed
                )

                # One generation pass yields both metrics. Separate MSE and
                # correlation passes would run generation twice and, if
                # generation is stochastic, score the metrics on different samples.
                metrics = evaluate_comprehensive(model, test_dl, device)
                test_mse = float(metrics['mse']['mean'])
                test_correlation = float(metrics['correlation']['mean'])

                train_subject_results[test_key] = test_mse
                train_subject_correlations[test_key] = test_correlation

                print(f"    Subject {test_subject} MSE: {test_mse:.6f}, Correlation: {test_correlation:.6f}")

            except Exception as e:
                # Best effort: record the failure and continue with the next test subject.
                print(f"    Failed to test subject {test_subject}: {e}")
                train_subject_results[test_key] = None
                train_subject_correlations[test_key] = None

        all_results[f'train_subject_{train_subject}'] = train_subject_results
        all_correlations[f'train_subject_{train_subject}'] = train_subject_correlations

    return all_results, all_correlations


def evaluate_correlation(model, dataloader, device):
    """Return the mean correlation between generated and real EEG over ``dataloader``.

    Each ``(eeg, image)`` batch is prepared as follows:
    - 4-D EEG is passed through ``utils.average_eeg_trials``.
    - 3-D EEG gets a singleton dimension at axis 1.
    - EEG is generated from the CLIP image embedding and scored with
      ``utils.compute_correlation``.

    Returns:
        ``np.mean`` of every collected correlation value (NaN if none were collected).
    """
    model.eval()

    collected = []

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Correlation evaluation', leave=False)
        for target_eeg, images in progress:
            target_eeg = target_eeg.float().to(device)
            images = images.float().to(device)

            if target_eeg.dim() == 4:
                target_eeg = utils.average_eeg_trials(target_eeg)
            # Normalization via utils.normalize_eeg is currently disabled.
            if target_eeg.dim() == 3:
                target_eeg = target_eeg.unsqueeze(1)

            embedding = model.clip_model.embed_image(images).float()
            predicted_eeg = model.diffusion_model.generate_eeg(embedding)

            scores = utils.compute_correlation(predicted_eeg, target_eeg)
            collected.extend(scores.cpu().numpy())

    return np.mean(collected)




def _save_ablation_results(results, save_dir):
    """Write ablation metrics (config -> metric -> statistic) to ``ablation_study.txt``."""
    result_path = os.path.join(save_dir, 'ablation_study.txt')
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write('Ablation Study Results\n')
        f.write('=' * 50 + '\n\n')
        for config_name, metrics in results.items():
            f.write(f'{config_name}:\n')
            for metric, values in metrics.items():
                f.write(f'  {metric}:\n')
                for key, value in values.items():
                    f.write(f'    {key}: {value:.6f}\n')
            f.write('\n')
    print(f'Ablation study completed, results saved in: {result_path}')


def _write_text_matrix(f, title, results, num_subjects=10):
    """Write a tab-separated train-by-test matrix to the open text file ``f``.

    Rows are only written for train subjects present in ``results``. Failed
    tests (None) are written as ``N/A``.
    """
    f.write(title)
    f.write('-' * 70 + '\n')

    # Header
    f.write('Train\\Test\t')
    for test_idx in range(1, num_subjects + 1):
        f.write(f'Sub{test_idx}\t')
    f.write('\n')

    # Data rows
    for train_idx in range(1, num_subjects + 1):
        row = results.get(f'train_subject_{train_idx}')
        if row is None:
            continue
        f.write(f'Sub{train_idx}\t\t')
        for test_idx in range(1, num_subjects + 1):
            value = row.get(f'test_subject_{test_idx}')
            f.write(f'{value:.4f}\t' if value is not None else 'N/A\t')
        f.write('\n')


def _write_cross_subject_report(result_path, mse_results, corr_results, num_subjects=10):
    """Write the human-readable cross-subject report.

    The report contains per-model details, per-model statistics, and the MSE
    and correlation matrices.
    """
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write('Cross-Subject Evaluation Results (MSE Loss)\n')
        f.write('=' * 80 + '\n\n')

        # Detailed results
        for train_subject, test_results in mse_results.items():
            f.write(f'Model of {train_subject} performance on various subjects:\n')
            f.write('-' * 50 + '\n')
            corr_row = corr_results.get(train_subject, {})
            for test_subject, mse_value in test_results.items():
                test_corr = corr_row.get(test_subject)
                if mse_value is not None and test_corr is not None:
                    f.write(f'  {test_subject}: MSE = {mse_value:.6f}, Correlation = {test_corr:.6f}\n')
                else:
                    f.write(f'  {test_subject}: Test failed\n')
            f.write('\n')

        # Average performance for each training subject
        f.write('Statistical Analysis:\n')
        f.write('=' * 50 + '\n')
        for train_subject, test_results in mse_results.items():
            valid_mse = [mse for mse in test_results.values() if mse is not None]
            valid_corr = [corr for corr in corr_results.get(train_subject, {}).values() if corr is not None]
            if valid_mse and valid_corr:
                avg_mse = np.mean(valid_mse)
                std_mse = np.std(valid_mse)
                avg_corr = np.mean(valid_corr)
                std_corr = np.std(valid_corr)
                f.write(f'{train_subject} average MSE: {avg_mse:.6f} ± {std_mse:.6f}, average correlation: {avg_corr:.6f} ± {std_corr:.6f}\n')
        f.write('\n')

        _write_text_matrix(
            f,
            'Cross-subject generalization matrix - MSE (rows for training subjects, columns for testing subjects):\n',
            mse_results,
            num_subjects,
        )
        _write_text_matrix(
            f,
            '\nCross-subject correlation matrix (rows for training subjects, columns for testing subjects):\n',
            corr_results,
            num_subjects,
        )


def _write_matrix_csv(csv_path, results, num_subjects=10):
    """Write a train-by-test matrix as CSV.

    Failed tests are written as ``NaN``. Train subjects missing from
    ``results`` get no row.
    """
    with open(csv_path, 'w', encoding='utf-8') as f:
        header = ['Train_Subject'] + [f'Test_Subject_{i}' for i in range(1, num_subjects + 1)]
        f.write(','.join(header) + '\n')
        for train_idx in range(1, num_subjects + 1):
            row = results.get(f'train_subject_{train_idx}')
            if row is None:
                continue
            cells = [str(train_idx)]
            for test_idx in range(1, num_subjects + 1):
                value = row.get(f'test_subject_{test_idx}')
                cells.append(f'{value:.6f}' if value is not None else 'NaN')
            f.write(','.join(cells) + '\n')


def _subject_matrix(results, num_subjects=10):
    """Build a (num_subjects, num_subjects) train-by-test array from nested results.

    Missing train subjects and failed tests stay NaN, so seaborn masks them.
    Previously such cells stayed 0, which a heatmap draws as a perfect MSE.
    """
    matrix = np.full((num_subjects, num_subjects), np.nan)
    for train_idx in range(1, num_subjects + 1):
        row = results.get(f'train_subject_{train_idx}')
        if row is None:
            continue
        for test_idx in range(1, num_subjects + 1):
            value = row.get(f'test_subject_{test_idx}')
            if value is not None:
                matrix[train_idx - 1, test_idx - 1] = value
    return matrix


def _save_heatmap(matrix, title, cmap, path):
    """Save ``matrix`` as an annotated heatmap at ``path``.

    The figure is always closed afterwards, so repeated runs do not
    accumulate open figures.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(matrix, annot=True, fmt=".4f", cmap=cmap,
                    xticklabels=range(1, matrix.shape[1] + 1),
                    yticklabels=range(1, matrix.shape[0] + 1))
        plt.title(title, fontsize=14)
        plt.xlabel("Test Subject", fontsize=12)
        plt.ylabel("Train Subject", fontsize=12)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)


def _run_cross_subject(args, num_subjects=10):
    """Run the cross-subject evaluation and save the text report, CSV matrices and heatmaps."""
    cross_subject_results, cross_subject_correlations = cross_subject_evaluation(args)

    result_path = os.path.join(args.save_dir, 'cross_subject_evaluation_mse.txt')
    _write_cross_subject_report(result_path, cross_subject_results, cross_subject_correlations, num_subjects)
    print(f'Cross-subject evaluation completed, results saved in: {result_path}')

    csv_path = os.path.join(args.save_dir, 'cross_subject_mse_matrix.csv')
    _write_matrix_csv(csv_path, cross_subject_results, num_subjects)
    print(f'MSE matrix saved in: {csv_path}')

    corr_csv_path = os.path.join(args.save_dir, 'cross_subject_correlation_matrix.csv')
    _write_matrix_csv(corr_csv_path, cross_subject_correlations, num_subjects)
    print(f'Correlation matrix saved in: {corr_csv_path}')

    # Best effort: a plotting failure must not affect the text/CSV results already written.
    try:
        mse_heatmap_path = os.path.join(args.save_dir, 'cross_subject_mse_heatmap.png')
        _save_heatmap(_subject_matrix(cross_subject_results, num_subjects),
                      "Cross-Subject MSE Matrix", "viridis", mse_heatmap_path)

        corr_heatmap_path = os.path.join(args.save_dir, 'cross_subject_correlation_heatmap.png')
        _save_heatmap(_subject_matrix(cross_subject_correlations, num_subjects),
                      "Cross-Subject Correlation Matrix", "coolwarm", corr_heatmap_path)

        print(f"Heatmaps saved in: {mse_heatmap_path} and {corr_heatmap_path}")
    except Exception as e:
        print(f"Failed to draw heatmaps: {e}")


def _run_comprehensive(args):
    """Evaluate the checkpoint ``<model_path>/best_model_<loss_type>.pth`` on ``args.subject`` and save the metrics."""
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Load data and model
    _, test_dl = get_eeg_dls(
        subject=args.subject,
        data_path=args.data_path,
        batch_size=args.batch_size,
        val_batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed
    )

    clip_model = Clipper(clip_variant=args.clip_variant, device=device)
    diffusion_model = DiffusionEEGModel(
        eeg_channels=args.eeg_channels,
        eeg_length=args.eeg_length,
        hidden_dim=args.hidden_dim,
        num_train_timesteps=args.num_train_timesteps,
        device=device
    ).to(device)

    model = ImageToEEGModel(clip_model, diffusion_model).to(device)

    # Unlike the cross-subject sweep, this path has no per-subject subfolder.
    model_path = os.path.join(args.model_path, f'best_model_{args.loss_type}.pth')
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    comprehensive_results = evaluate_comprehensive(model, test_dl, device)

    result_path = os.path.join(args.save_dir, 'comprehensive_evaluation.txt')
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write('Comprehensive Evaluation Results\n')
        f.write('=' * 50 + '\n\n')
        for metric, values in comprehensive_results.items():
            f.write(f'{metric}:\n')
            for key, value in values.items():
                f.write(f'  {key}: {value:.6f}\n')
            f.write('\n')

    print(f'Comprehensive evaluation completed, results saved in: {result_path}')


def main():
    """Parse command-line arguments and run the requested experiments.

    The ``--run_ablation``, ``--run_cross_subject`` and ``--run_comprehensive``
    flags are independent; each enabled experiment writes its results under
    ``--save_dir``.
    """
    parser = argparse.ArgumentParser(description='Supplementary experiment analysis')

    # Data parameters
    parser.add_argument('--data_path', type=str, required=True, help='Dataset path')
    parser.add_argument('--subject', type=int, default=1, help='Subject number')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Data loader process count')

    # Model parameters
    parser.add_argument('--clip_variant', type=str, default='ViT-L/14')
    parser.add_argument('--eeg_channels', type=int, default=63)
    parser.add_argument('--eeg_length', type=int, default=250)
    parser.add_argument('--hidden_dim', type=int, default=768)
    parser.add_argument('--num_train_timesteps', type=int, default=1000)

    # Experiment parameters
    parser.add_argument('--model_path', type=str, required=True, help='Model path')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save_dir', type=str, default='./experiment_results')
    parser.add_argument('--loss_type', type=str, default='mse')

    # Experiment types
    parser.add_argument('--run_ablation', action='store_true', help='Run ablation experiment')
    parser.add_argument('--run_cross_subject', action='store_true', help='Run cross-subject experiment')
    parser.add_argument('--run_comprehensive', action='store_true', help='Run comprehensive evaluation')

    args = parser.parse_args()

    # Set random seed
    utils.seed_everything(args.seed)

    # Create save directory
    os.makedirs(args.save_dir, exist_ok=True)

    if args.run_ablation:
        ablation_results = ablation_study(args)
        # Previously these results were computed and then discarded.
        _save_ablation_results(ablation_results, args.save_dir)

    if args.run_cross_subject:
        _run_cross_subject(args)

    if args.run_comprehensive:
        _run_comprehensive(args)
    

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()